0dd97230690f299f4aa86a4c510de88ad3d6b63d,src/main/java/ml/shifu/shifu/core/dtrain/nn/NNWorker.java,NNWorker,load,#GuaguaWritableAdapter#GuaguaWritableAdapter#WorkerContext#,53

Before Change


        // if fixInitialInput = false, we only compare random value with baggingSampleRate to avoid parsing data.
        // if fixInitialInput = true, we should use hashcode after parsing.
        double baggingSampleRate = super.modelConfig.getBaggingSampleRate();
        if(!super.modelConfig.isFixInitialInput() && Double.compare(Math.random(), baggingSampleRate) >= 0) {
            // for negative tags, do sampleNegOnly logic
            if(modelConfig.getTrain().getSampleNegOnly()) {
                if(modelConfig.isRegression() && Double.compare(ideal[0] + 0d, 0d) == 0) {
                    return;
                }
            } else {
                return;// normal sampling
            }
        }
        // if fixInitialInput = true, we should use hashcode to sample.
        long longBaggingSampleRate = Double.valueOf(baggingSampleRate * 100).longValue();
        if(super.modelConfig.isFixInitialInput() && hashcode % 100 >= longBaggingSampleRate) {
            // for negative tags, do sampleNegOnly logic
            if(modelConfig.getTrain().getSampleNegOnly()) {
                if(modelConfig.isRegression() && Double.compare(ideal[0] + 0d, 0d) == 0) {
                    return;
                }
            } else {
                return;// normal sampling
            }
        }

        // count stats after sampling
        super.sampleCount += 1;

        FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));

        if(modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
            // Double.compare(ideal[0], 1d) == 0 means a positive tag; add 1 to the sampled value so the sample count is never 0
            pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
        } else {
            pair.setSignificance(significance);
        }
        boolean isTesting = false;
        if(workerContext.getAttachment() != null && workerContext.getAttachment() instanceof Boolean) {
            isTesting = (Boolean) workerContext.getAttachment();
        }
        addDataPairToDataSet(hashcode, pair, isTesting);
    }

    /*

After Change


        }

        // If only negative records are sampled, do that sampling here, regardless of bagging or replacement.
        if(modelConfig.getTrain().getSampleNegOnly() // sample negative enabled
                && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain()
                        .isOneVsAll())) // regression or onevsall
                && Double.compare(ideal[0] + 0.01d, 0d) == 0 // negative record
                && (!this.modelConfig.isFixInitialInput() && Double.compare(Math.random(),
                        this.modelConfig.getBaggingSampleRate()) >= 0)) {
            return;
        }
        if(modelConfig.getTrain().getSampleNegOnly()// sample negative enabled
                && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain()
                        .isOneVsAll()))// regression or onevsall
                && (Double.compare(ideal[0] + 0.01d, 0d) == 0 // negative record
                        && this.modelConfig.isFixInitialInput() && hashcode % 100 >= Double.valueOf(
                        this.modelConfig.getBaggingSampleRate() * 100).longValue())) {
            return;
        }

        FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));

        // up-sampling logic: only add more weight to the record; the bagging sampling rate is left unchanged
        if(modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
            // Double.compare(ideal[0], 1d) == 0 means a positive tag; add 1 to the sampled value so the sample count is never 0
            pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
        } else {
            pair.setSignificance(significance);
        }

        boolean isValidation = false;
        if(workerContext.getAttachment() != null && workerContext.getAttachment() instanceof Boolean) {
            isValidation = (Boolean) workerContext.getAttachment();
        }

        boolean isInTraining = addDataPairToDataSet(hashcode, pair, isValidation);

        // do bagging sampling only for training data
        if(isInTraining) {
            float subsampleWeights = sampleWeights(pair.getIdealArray()[0]);
            if(isPositive(pair.getIdealArray()[0])) {
                this.positiveSelectedTrainCount += subsampleWeights * 1L;
            } else {
                this.negativeSelectedTrainCount += subsampleWeights * 1L;